-
Notifications
You must be signed in to change notification settings - Fork 2
Fix bipermutations in contract
#75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #75 +/- ##
==========================================
- Coverage 94.90% 94.16% -0.74%
==========================================
Files 14 14
Lines 451 463 +12
==========================================
+ Hits 428 436 +8
- Misses 23 27 +4
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
# default: if no bipartion is specified, all axes to domain | ||
invbiperm(perm, ::Any) = invbiperm(perm, Val(0)) | ||
invbiperm(perm, t::Tuple{Tuple,Tuple}) = invbiperm(perm, tuplemortar(t)) | ||
invbiperm(perm, t::AbstractBlockTuple{2}) = invbiperm(perm, Val(first(blocklength(t)))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
invbiperm(perm, t::AbstractBlockTuple{2}) = invbiperm(perm, Val(first(blocklength(t)))) | |
invbiperm(perm, t::AbstractBlockTuple{2}) = invbiperm(perm, Val(first(blocklengths(t)))) |
using BlockArrays: blocklengths | ||
|
||
# default: if no bipartion is specified, all axes to domain | ||
invbiperm(perm, ::Any) = invbiperm(perm, Val(0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit strange to me that it just allows anything as the second argument. Maybe this should be invbiperm(perm)
? Is this used anywhere right now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at the rest of the code, I see that it is being used in calls like biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, axes(a_dest))
, where axes(a_dest)
might output a blocked tuple or a flat tuple.
I think the invbiperm
function is trying to do too much and therefore makes the code harder to understand. Instead, maybe we could introduce new functions biperm
and length_codomain
:
function biperm(perm, blocklength1::Int)
return biperm(perm, Val(blocklength1))
end
function biperm(perm, ::Val{BlockLength1}) where {BlockLength1}
# Check: BlockLength1 <= length(perm)
return blockedperm(Tuple(perm),(BlockLength1, length(perm) - BlockLength1))
end
length_codomain(t::AbstractBlockTuple{2}) = first(blocklength(t))
# Assume all dimensions are in the domain by default
length_codomain(t) = 0
and the use a combination of invperm
, biperm
, and length_codomain
in the contract code, for example:
biperm_a12_to_dest = biperm(invperm(biperm_dest_to_a12), length_codomain(axes(a_dest)))
function unmatricize( | ||
m::AbstractMatrix, axes_dest, biperm_dest_to_a12::AbstractBlockPermutation{2} | ||
) | ||
length(axes_dest) == length(biperm_dest_to_a12) || | ||
throw(ArgumentError("axes do not match permutation")) | ||
return unmatricize(FusionStyle(m), m, axes_dest, biperm_dest_to_a12) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think in the context of this function, the name biperm_dest_to_a12
is more confusing than helpful (that name only makes sense in the context of contract
but this function could be called for other purposes). I think it should be clear that biperm
means the permutation that should be performed on m
after it is reinterpreted as a length(axes_dest)
-dimensional array (unless I'm misunderstanding the conventions of this function, in which case we should change it to that convention).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same thing with the name axes_dest
, I think axes
is clear enough.
axes_dest, | ||
biperm_dest_to_a12::AbstractBlockPermutation{2}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comments as above about the naming.
return permuteblockeddims(a_perm, invperm(biperm)) | ||
blocked_axes = axes_dest[biperm_dest_to_a12] | ||
a12 = unmatricize(m, blocked_axes) | ||
biperm_a12_to_dest = invbiperm(biperm_dest_to_a12, axes_dest) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I lost track of the discussions we had about the conventions we want to use in this PR, I thought we had discussed that we would change the convention of unmatricize
so that the biperm that gets input would be taken "literally", i.e. it wouldn't need to be inverted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe as an alternative, we could change the convention of unmatricize
so that the axes that get input are the unpermuted axes, i.e. the axes corresponding directly to the memory ordering of the input matrix m
. Then, the bipermutation that gets input is just the bipermutation that needs to be done to get the desired output. I.e. it would be equivalent to:
function unmatricize(
style::FusionStyle,
m::AbstractMatrix,
ax,
biperm::AbstractBlockPermutation{2},
)
a = unmatricize(style, m, ax)
return permutedims(a, biperm)
end
This PR fixes bipermutation inversion in
contract
. A given bipermutation alone does not carry enough information to be inverted: the bipartition of the output must be specified. It turns out both permutation are needed incontract
stack at different times. I managed to pass only one inside each function and to invert it using information from other arguments.The logic of the code is now correct and able to handle arrays with bipartitions. We need bipermutations for both directions and be explicit which is which. I would be happy to improve the names I used.
Note that
BlockArrays
tests fail due to JuliaArrays/BlockArrays.jl#295.